import sys
import traceback
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from numpy import linalg as la


def centered_cov_torch(x):
    """Return the unbiased covariance ``x.T @ x / (n - 1)`` of already-centred rows.

    No centring is performed here: ``x`` is expected to be an ``(n, d)`` tensor
    whose rows already have the mean subtracted.

    With a single row the denominator ``n - 1`` would be zero. In that case the
    denominator is forced to 1 and ``1e-5`` is added to every entry, because a
    centred single row is all zeros. The offset is applied out of place, so the
    caller's tensor is never modified (and integer inputs no longer fail on an
    in-place float add).

    Args:
        x: ``(n, d)`` tensor of centred observations.

    Returns:
        ``(d, d)`` covariance tensor.
    """
    n = x.shape[0]
    if n == 1:
        n = 2
        # Out-of-place: `x += ...` would silently mutate the caller's tensor.
        x = x + 0.00001
    return 1 / (n - 1) * x.t().mm(x)


def centered_cov_numpy(x):
    """Return the unbiased covariance ``x.T @ x / (n - 1)`` of already-centred rows.

    NumPy counterpart of ``centered_cov_torch``. No centring is performed:
    ``x`` is expected to be an ``(n, d)`` array whose rows already have the mean
    subtracted.

    With a single row the denominator is forced to 1 and ``1e-5`` is added to
    every entry, because a centred single row is all zeros. The offset is applied
    out of place, so the caller's array is never modified (and integer arrays no
    longer raise on an in-place float add).

    Args:
        x: ``(n, d)`` array of centred observations.

    Returns:
        ``(d, d)`` covariance array.
    """
    n = x.shape[0]
    if n == 1:
        n = 2
        # Out-of-place: `x += ...` would silently mutate the caller's array.
        x = x + 0.00001
    return 1 / (n - 1) * np.matmul(x.T, x)


def gmm_forward(net, gaussians_model, data_B_X):
    """Run ``net`` on one batch and score its captured embedding under ``gaussians_model``.

    The forward call is made only for its side effect. The embedding that gets
    scored is read back from an attribute of the network after the call, not
    from the call's return value:

    * ``nn.DataParallel`` wrapper: the inner ``net.module`` is called directly,
      without DataParallel's scatter, and ``net.module.feature`` is read.
    * any other module: ``net`` is called and ``net.measure`` is read.

    Args:
        net: the network, optionally wrapped in ``nn.DataParallel``.
        gaussians_model: any object with a ``log_prob`` method, e.g. the
            ``MultivariateNormal`` built by ``gmm_fit``.
        data_B_X: the input batch.

    Returns:
        ``gaussians_model.log_prob(embedding[:, None, :])``. The singleton axis
        lets a ``(B, D)`` embedding broadcast against a model whose batch shape
        is ``(C,)``, giving ``(B, C)``.
    """
    if isinstance(net, nn.DataParallel):
        inner = net.module
        inner(data_B_X)
        embedding = inner.feature
    else:
        # NOTE(review): this branch reads `.measure` while the DataParallel branch
        # reads `.feature`; confirm the asymmetry is intentional.
        net(data_B_X)
        embedding = net.measure

    return gaussians_model.log_prob(embedding[:, None, :])


def gmm_evaluate(net, gaussians_model, loader, device, num_classes, storage_device):
    """Score every batch from ``loader`` with ``gmm_forward`` and collect results.

    Runs under ``torch.no_grad()``. It does not call ``net.eval()``, so the caller
    is responsible for putting the network in the right mode.

    Args:
        net: network passed through to ``gmm_forward``.
        gaussians_model: model passed through to ``gmm_forward``.
        loader: iterable of ``(data, label)`` batches with a ``.dataset`` that
            supports ``len()``. ``len(loader.dataset)`` sizes the output buffers.
        device: device the input batches are moved to before the forward pass.
        num_classes: width of the score buffer. Presumably this equals the
            number of Gaussian components (it must match the width that
            ``gmm_forward`` returns).
        storage_device: device the result buffers live on.

    Returns:
        ``(logits_N_C, labels_N)`` on ``storage_device``: float32 scores of shape
        ``(rows, num_classes)`` and int32 labels of shape ``(rows,)``, where
        ``rows`` is the number of samples the loader actually yielded.
    """
    num_samples = len(loader.dataset)
    logits_N_C = torch.empty((num_samples, num_classes), dtype=torch.float, device=storage_device)
    labels_N = torch.empty(num_samples, dtype=torch.int, device=storage_device)

    end = 0
    with torch.no_grad():
        for data, label in tqdm(loader):
            data = data.to(device)
            logit_B_C = gmm_forward(net, gaussians_model, data)

            start, end = end, end + len(data)
            logits_N_C[start:end].copy_(logit_B_C, non_blocking=True)
            # copy_ handles the cross-device copy itself, so there is no need to
            # send the labels to `device` first and then straight back.
            labels_N[start:end].copy_(label, non_blocking=True)

    # The buffers came from torch.empty. If the loader yielded fewer rows than
    # len(loader.dataset) (e.g. drop_last=True or a subset sampler), the tail was
    # never written, so return only the filled prefix, not uninitialised memory.
    return logits_N_C[:end], labels_N[:end]


def gmm_get_logits(gmm, embeddings):
    """Score precomputed embeddings under ``gmm`` (no network forward pass).

    Args:
        gmm: an object with a ``log_prob`` method, e.g. the batched
            ``MultivariateNormal`` returned by ``gmm_fit``.
        embeddings: array/tensor indexed as ``embeddings[:, None, :]``.

    Returns:
        ``gmm.log_prob`` of the embeddings with a singleton axis inserted at
        position 1. For ``(N, D)`` embeddings and a model with batch shape
        ``(C,)``, this is an ``(N, C)`` tensor.
    """
    broadcastable = embeddings[:, None, :]
    return gmm.log_prob(broadcastable)


def gmm_fit(embeddings, labels, class_labels=None, apply_pd=False):
    """Fit one full-covariance Gaussian per class and bundle them into a batched MVN.

    Per-class means and covariances are computed in NumPy with
    ``centered_cov_numpy``, and can optionally be projected to the nearest
    positive-definite matrix. The torch distribution is then built, retrying
    with an increasing diagonal jitter until PyTorch accepts the covariances.

    Args:
        embeddings: NumPy array of shape ``(N, D)``.
        labels: length-``N`` array. ``labels == c`` selects the rows of class ``c``.
        class_labels: label values to fit, one Gaussian per entry, in this order.
            ``None`` (or an empty sequence) means ``np.unique(labels)``.
        apply_pd: if True, replace each class covariance with ``nearestPD`` of it
            before building the distribution.

    Returns:
        ``(gmm, jitter_eps)``. ``gmm`` is a ``torch.distributions.MultivariateNormal``
        with batch shape ``(C,)``, where ``C = len(class_labels)``. ``jitter_eps``
        is the scalar added to every covariance diagonal to make construction
        succeed.

    Raises:
        RuntimeError: if no value in the jitter ladder yields a distribution that
            PyTorch accepts. The last construction error is chained as the cause.
    """
    if class_labels is None or len(class_labels) == 0:
        class_labels = np.unique(labels)

    double_info = np.finfo(np.float64)
    # Jitter ladder: 0, the smallest normal double, then 1e-308, 1e-307, ..., 1e9.
    jitters = [0, double_info.tiny] + [10 ** exp for exp in range(-308, 10, 1)]

    class_means, class_covs = [], []
    for label in class_labels:
        members = embeddings[labels == label]
        mean = np.mean(members, axis=0)
        class_means.append(mean)
        class_covs.append(centered_cov_numpy(members - mean))
    classwise_mean_features = np.stack(class_means)
    classwise_cov_features = np.stack(class_covs)

    if apply_pd:
        classwise_cov_features = np.stack([nearestPD(c) for c in classwise_cov_features], axis=0)

    with torch.no_grad():
        loc = torch.tensor(classwise_mean_features)
        cov = torch.tensor(classwise_cov_features)
        # torch.eye defaults to the global default dtype (float32). Under that
        # default a float64 covariance received a float32 jitter, so every ladder
        # value below ~1e-45 flushed to zero. Build the identity in the
        # covariance's dtype, never narrower than the default dtype, which
        # preserves the old promotion for half-precision inputs.
        eye_dtype = torch.promote_types(cov.dtype, torch.get_default_dtype())
        eye = torch.eye(cov.shape[1], dtype=eye_dtype, device=cov.device).unsqueeze(0)

        last_error = None
        for jitter_eps in jitters:
            try:
                gmm = torch.distributions.MultivariateNormal(
                    loc=loc, covariance_matrix=cov + jitter_eps * eye,
                )
            except (RuntimeError, ValueError) as err:
                # ValueError: argument validation rejected the matrix.
                # RuntimeError (incl. torch.linalg.LinAlgError): Cholesky failed.
                last_error = err
                continue

            print('Apply', jitter_eps)
            return gmm, jitter_eps

    raise RuntimeError('gmm_fit: no jitter value produced a valid covariance matrix') from last_error


def gmm_fit_torch(embeddings, labels, class_labels=None):
    """Torch-only variant of ``gmm_fit``: one full-covariance Gaussian per class.

    Means and covariances (via ``centered_cov_torch``) are computed on the
    embeddings' own device. The distribution is then built, retrying with an
    increasing diagonal jitter until PyTorch accepts the covariances.

    Args:
        embeddings: tensor of shape ``(N, D)``.
        labels: length-``N`` tensor. ``labels == c`` selects the rows of class ``c``.
        class_labels: label values to fit, one Gaussian per entry, in this order.
            ``None`` (or an empty sequence) means ``torch.unique(labels)``.

    Returns:
        ``(gmm, jitter_eps)``. ``gmm`` is a ``torch.distributions.MultivariateNormal``
        with batch shape ``(C,)``. ``jitter_eps`` is the scalar added to every
        covariance diagonal to make construction succeed. For float32 embeddings,
        ladder values below the float32 range still round to zero.

    Raises:
        RuntimeError: if no value in the jitter ladder yields a distribution that
            PyTorch accepts. The last construction error is chained as the cause.
    """
    if class_labels is None or len(class_labels) == 0:
        class_labels = torch.unique(labels)

    double_info = torch.finfo(torch.double)
    # Jitter ladder: 0, the smallest normal double, then 1e-308, 1e-307, ..., 1e9.
    jitters = [0, double_info.tiny] + [10 ** exp for exp in range(-308, 10, 1)]

    with torch.no_grad():
        class_means, class_covs = [], []
        for label in class_labels:
            members = embeddings[labels == label]
            mean = torch.mean(members, dim=0)
            class_means.append(mean)
            class_covs.append(centered_cov_torch(members - mean))
        classwise_mean_features = torch.stack(class_means)
        classwise_cov_features = torch.stack(class_covs)

        # Build the identity in the covariance's dtype, never narrower than the
        # default dtype. A default-dtype (float32) eye would shrink the jitter of
        # float64 covariances to float32, flushing tiny ladder values to zero.
        eye_dtype = torch.promote_types(classwise_cov_features.dtype, torch.get_default_dtype())
        eye = torch.eye(
            classwise_cov_features.shape[1], dtype=eye_dtype, device=classwise_cov_features.device,
        ).unsqueeze(0)

        last_error = None
        for jitter_eps in jitters:
            try:
                gmm = torch.distributions.MultivariateNormal(
                    loc=classwise_mean_features,
                    covariance_matrix=classwise_cov_features + jitter_eps * eye,
                )
            except (RuntimeError, ValueError) as err:
                # ValueError: argument validation rejected the matrix.
                # RuntimeError (incl. torch.linalg.LinAlgError): Cholesky failed.
                last_error = err
                continue

            print('Apply', jitter_eps)
            return gmm, jitter_eps

    raise RuntimeError('gmm_fit_torch: no jitter value produced a valid covariance matrix') from last_error


def nearestPD(A):
    """Return the nearest positive-definite matrix to ``A``.

    A Python/NumPy port of John D'Errico's ``nearestSPD`` MATLAB code [1], which
    is based on Higham's method [2]:

    1. Symmetrise ``A``.
    2. Average the result with its symmetric polar factor.
    3. If the result is still not Cholesky-factorisable (checked with ``isPD``),
       repeatedly shift the diagonal up by a growing multiple of the most
       negative eigenvalue until it is.

    [1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd

    [2] N.J. Higham, "Computing a nearest symmetric positive semidefinite
    matrix" (1988): https://doi.org/10.1016/0024-3795(88)90223-6
    """
    sym = (A + A.T) / 2
    _, sing_vals, v_h = la.svd(sym)

    # Symmetric polar factor of `sym`: V diag(s) V^T.
    polar = np.dot(v_h.T, np.dot(np.diag(sing_vals), v_h))

    half = (sym + polar) / 2
    # Re-symmetrise to remove asymmetry introduced by rounding.
    candidate = (half + half.T) / 2

    if isPD(candidate):
        return candidate

    # The tolerance below differs from [1]. MATLAB's `chol` Cholesky
    # decomposition appears to accept matrices with an exactly-zero eigenvalue,
    # whereas NumPy's does not. So where [1] uses `eps(mineig)` (MATLAB's `eps`
    # is `np.spacing`), we use the definition below. CAVEAT: this tolerance is
    # much larger than [1]'s `eps(mineig)`. `mineig` is usually on the order of
    # 1e-16, so `eps(1e-16)` is on the order of 1e-34, whereas this tolerance
    # is on the order of 1e-16 for small Gaussian random matrices. In practice
    # both ways converge, as the self-test under `__main__` suggests.
    tol = np.spacing(la.norm(A))
    identity = np.eye(A.shape[0])

    k = 1
    while True:
        lowest = np.min(np.real(la.eigvals(candidate)))
        candidate += identity * (-lowest * k ** 2 + tol)
        k += 1
        if isPD(candidate):
            return candidate


def isPD(B):
    """Report whether ``B`` is positive-definite by attempting a Cholesky factorisation.

    ``numpy.linalg.cholesky`` reads only one triangle of ``B`` and does not check
    symmetry, so this is a positive-definiteness test only for symmetric input.

    Returns:
        True if the factorisation succeeds, False if it raises ``LinAlgError``.
        Any other exception propagates.
    """
    try:
        la.cholesky(B)
    except la.LinAlgError:
        return False
    return True


if __name__ == '__main__':
    # Self-test: nearestPD must return a Cholesky-factorisable matrix for random
    # square matrices of every size from 2x2 to 99x99, ten rounds each. The
    # generator is seeded so that any failure is reproducible. An explicit raise
    # is used instead of `assert` so the check also runs under `python -O`.
    rng = np.random.default_rng(0)
    for _ in range(10):
        for size in range(2, 100):
            A = rng.standard_normal((size, size))
            B = nearestPD(A)
            if not isPD(B):
                raise AssertionError(f'nearestPD returned a non-PD matrix for size {size}')
    print('unit test passed!')
